{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import argparse\n",
    "import os\n",
    "import torch\n",
    "from collections import OrderedDict\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.io\n",
    "from scipy.interpolate import griddata\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "import matplotlib.gridspec as gridspec\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.utils import save_image\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from torchvision import datasets\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "import seaborn as sns\n",
    "import pylab as py\n",
    "import time\n",
    "from pyDOE import lhs\n",
    "import warnings\n",
    "sys.path.insert(0, '../../../Scripts/')\n",
    "from models_imperfect import Generator, Discriminator, Q_Net\n",
    "from pig import *\n",
    "# from ../Scripts/helper import *\n",
    "\n",
    "# NOTE: this silences every warning category for the rest of the session.\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Seed PyTorch as well as NumPy: the networks are torch objects, so seeding\n",
    "# NumPy alone leaves weight initialisation (and any torch-drawn noise) unseeded.\n",
    "np.random.seed(1234)\n",
    "torch.manual_seed(1234)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CUDA support\n",
    "# GPU 7 is the preferred device. On hosts that expose fewer GPUs fall back to\n",
    "# GPU 0 instead of failing on an invalid device ordinal; use the CPU when CUDA\n",
    "# is not available at all.\n",
    "preferred_gpu = 7\n",
    "if torch.cuda.is_available():\n",
    "    gpu_index = preferred_gpu if preferred_gpu < torch.cuda.device_count() else 0\n",
    "    device = torch.device('cuda', gpu_index)\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyper-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training schedule and loss weights\n",
    "num_epochs = 10000   # forwarded to Tossing_PIG\n",
    "lambda_val = 1       # forwarded to Tossing_PIG\n",
    "lambda_q = 0.5       # NOTE(review): not referenced by any other cell in this notebook - confirm it is still needed\n",
    "tr_frac = 0.4        # leading fraction of the rows used as the training split\n",
    "\n",
    "# Network sizes, given as (hidden width, number of layers)\n",
    "d_hid_dim, d_num_layer = 50, 2   # Discriminator\n",
    "g_hid_dim, g_num_layer = 50, 4   # Generator\n",
    "q_hid_dim, q_num_layer = 50, 3   # Q_Net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each row of the text file holds 6 input columns followed by the label columns.\n",
    "data = np.loadtxt('../../../datasets/tossing_trajectories.txt')\n",
    "x, labels = data[:, 0:6], data[:, 6:]\n",
    "\n",
    "# Sequential (unshuffled) split: the first `tr_frac` of the rows train, the rest test.\n",
    "n_obs = int(tr_frac * x.shape[0])\n",
    "train_x, test_x = x[:n_obs, :], x[n_obs:, :]\n",
    "train_y, test_y = labels[:n_obs, :], labels[n_obs:, :]\n",
    "\n",
    "# Dimensions used to size the networks.\n",
    "noise_dim = 2\n",
    "data_dim = train_x.shape[1]\n",
    "out_dim = labels.shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Discriminator: (data_dim + out_dim) inputs -> 1 output\n",
    "D = Discriminator(\n",
    "    in_dim=(data_dim + out_dim),\n",
    "    out_dim=1,\n",
    "    hid_dim=d_hid_dim,\n",
    "    num_layers=d_num_layer,\n",
    ").to(device)\n",
    "\n",
    "# Generator: (noise_dim + data_dim) inputs -> out_dim outputs\n",
    "G = Generator(\n",
    "    in_dim=(noise_dim + data_dim),\n",
    "    out_dim=out_dim,\n",
    "    hid_dim=g_hid_dim,\n",
    "    num_layers=g_num_layer,\n",
    ").to(device)\n",
    "\n",
    "# Q network: (out_dim + data_dim) inputs -> noise_dim outputs\n",
    "Q = Q_Net(\n",
    "    in_dim=(out_dim + data_dim),\n",
    "    out_dim=noise_dim,\n",
    "    hid_dim=q_hid_dim,\n",
    "    num_layers=q_num_layer,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle the data splits, the three networks and the training settings.\n",
    "PIG = Tossing_PIG(\n",
    "    train_x, train_y, test_x, test_y,\n",
    "    G, D, Q,\n",
    "    device, num_epochs, lambda_val, noise_dim,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0/10000] [MSE loss: 3.962500] [G loss: 446.361282] [D loss: 5.473109] [Q loss: 4.020192] [Phy loss: 444.083237] [Adv G loss: 0.540978]\n",
      "[Epoch 100/10000] [MSE loss: 0.240060] [G loss: 17.013950] [D loss: 5.197514] [Q loss: 3.390023] [Phy loss: 14.638809] [Adv G loss: 0.605100]\n",
      "[Epoch 200/10000] [MSE loss: 0.088430] [G loss: 4.475567] [D loss: 5.285515] [Q loss: 3.870740] [Phy loss: 1.731760] [Adv G loss: 0.591088]\n",
      "[Epoch 300/10000] [MSE loss: 0.189821] [G loss: 4.535903] [D loss: 5.144568] [Q loss: 4.220296] [Phy loss: 1.773377] [Adv G loss: 1.004580]\n",
      "[Epoch 400/10000] [MSE loss: 0.100596] [G loss: 5.806224] [D loss: 4.837609] [Q loss: 4.285120] [Phy loss: 1.955797] [Adv G loss: 1.948232]\n",
      "[Epoch 500/10000] [MSE loss: 0.028258] [G loss: 6.232815] [D loss: 5.001040] [Q loss: 4.401441] [Phy loss: 1.881438] [Adv G loss: 2.160344]\n",
      "[Epoch 600/10000] [MSE loss: 0.049339] [G loss: 12.184770] [D loss: 3.384433] [Q loss: 4.154646] [Phy loss: 2.456382] [Adv G loss: 7.952081]\n",
      "[Epoch 700/10000] [MSE loss: 0.132035] [G loss: 13.735851] [D loss: 3.386804] [Q loss: 4.296238] [Phy loss: 2.398830] [Adv G loss: 9.385897]\n",
      "[Epoch 800/10000] [MSE loss: 0.033954] [G loss: 11.358492] [D loss: 3.326930] [Q loss: 4.313107] [Phy loss: 2.804109] [Adv G loss: 6.699196]\n",
      "[Epoch 900/10000] [MSE loss: 0.088026] [G loss: 7.551250] [D loss: 4.723400] [Q loss: 3.554083] [Phy loss: 2.758538] [Adv G loss: 2.917157]\n",
      "[Epoch 1000/10000] [MSE loss: 0.070228] [G loss: 12.150185] [D loss: 3.335288] [Q loss: 4.172791] [Phy loss: 2.811942] [Adv G loss: 7.254869]\n",
      "[Epoch 1100/10000] [MSE loss: 0.091149] [G loss: 20.935643] [D loss: 3.556434] [Q loss: 4.344420] [Phy loss: 3.761749] [Adv G loss: 15.248905]\n",
      "[Epoch 1200/10000] [MSE loss: 0.046068] [G loss: 22.078547] [D loss: 2.964614] [Q loss: 3.554410] [Phy loss: 3.495226] [Adv G loss: 16.539324]\n",
      "[Epoch 1300/10000] [MSE loss: 0.104856] [G loss: 14.612611] [D loss: 3.710269] [Q loss: 3.941014] [Phy loss: 2.857226] [Adv G loss: 9.918754]\n",
      "[Epoch 1400/10000] [MSE loss: 0.046394] [G loss: 18.650222] [D loss: 4.974720] [Q loss: 3.962880] [Phy loss: 3.541915] [Adv G loss: 13.260763]\n",
      "[Epoch 1500/10000] [MSE loss: 0.069489] [G loss: 14.862463] [D loss: 3.172724] [Q loss: 3.418728] [Phy loss: 2.990061] [Adv G loss: 9.722209]\n",
      "[Epoch 1600/10000] [MSE loss: 0.057857] [G loss: 19.851982] [D loss: 2.850026] [Q loss: 3.736558] [Phy loss: 2.942004] [Adv G loss: 14.689411]\n",
      "[Epoch 1700/10000] [MSE loss: 0.031350] [G loss: 13.080307] [D loss: 3.747521] [Q loss: 4.475796] [Phy loss: 2.957549] [Adv G loss: 8.199715]\n",
      "[Epoch 1800/10000] [MSE loss: 0.050715] [G loss: 18.201846] [D loss: 4.918853] [Q loss: 3.959160] [Phy loss: 3.472237] [Adv G loss: 12.763205]\n",
      "[Epoch 1900/10000] [MSE loss: 0.021590] [G loss: 19.811754] [D loss: 3.529107] [Q loss: 4.218115] [Phy loss: 3.332116] [Adv G loss: 14.573066]\n",
      "[Epoch 2000/10000] [MSE loss: 0.023935] [G loss: 16.003659] [D loss: 4.104660] [Q loss: 4.382065] [Phy loss: 2.791054] [Adv G loss: 11.406465]\n",
      "[Epoch 2100/10000] [MSE loss: 0.038183] [G loss: 14.632835] [D loss: 2.641730] [Q loss: 3.840240] [Phy loss: 2.873179] [Adv G loss: 9.718184]\n",
      "[Epoch 2200/10000] [MSE loss: 0.049163] [G loss: 15.127453] [D loss: 2.534636] [Q loss: 3.418193] [Phy loss: 2.735205] [Adv G loss: 10.292371]\n",
      "[Epoch 2300/10000] [MSE loss: 0.027098] [G loss: 17.765474] [D loss: 3.282060] [Q loss: 3.756736] [Phy loss: 2.992340] [Adv G loss: 12.919801]\n",
      "[Epoch 2400/10000] [MSE loss: 0.042595] [G loss: 16.892780] [D loss: 4.760192] [Q loss: 4.249917] [Phy loss: 2.952739] [Adv G loss: 12.123801]\n",
      "[Epoch 2500/10000] [MSE loss: 0.022180] [G loss: 13.389529] [D loss: 3.412488] [Q loss: 3.414953] [Phy loss: 2.788249] [Adv G loss: 8.283748]\n",
      "[Epoch 2600/10000] [MSE loss: 0.019684] [G loss: 13.141027] [D loss: 3.799151] [Q loss: 3.825731] [Phy loss: 2.774552] [Adv G loss: 8.367995]\n",
      "[Epoch 2700/10000] [MSE loss: 0.069738] [G loss: 28.012947] [D loss: 2.358599] [Q loss: 3.754153] [Phy loss: 2.692716] [Adv G loss: 23.359175]\n",
      "[Epoch 2800/10000] [MSE loss: 0.020127] [G loss: 15.049390] [D loss: 3.362330] [Q loss: 3.902502] [Phy loss: 3.187613] [Adv G loss: 10.070980]\n",
      "[Epoch 2900/10000] [MSE loss: 0.064689] [G loss: 25.310244] [D loss: 2.483940] [Q loss: 4.260566] [Phy loss: 2.816483] [Adv G loss: 20.494077]\n",
      "[Epoch 3000/10000] [MSE loss: 0.050341] [G loss: 18.400850] [D loss: 2.489586] [Q loss: 4.268894] [Phy loss: 3.027370] [Adv G loss: 13.136622]\n",
      "[Epoch 3100/10000] [MSE loss: 0.009070] [G loss: 20.955707] [D loss: 4.587405] [Q loss: 4.555343] [Phy loss: 2.812436] [Adv G loss: 16.463896]\n",
      "[Epoch 3200/10000] [MSE loss: 0.026090] [G loss: 20.467454] [D loss: 1.878608] [Q loss: 4.427259] [Phy loss: 2.463001] [Adv G loss: 15.946347]\n",
      "[Epoch 3300/10000] [MSE loss: 0.029657] [G loss: 19.257651] [D loss: 3.052827] [Q loss: 3.759835] [Phy loss: 2.718450] [Adv G loss: 14.516039]\n",
      "[Epoch 3400/10000] [MSE loss: 0.016119] [G loss: 20.191571] [D loss: 3.147163] [Q loss: 4.096237] [Phy loss: 2.715660] [Adv G loss: 15.757949]\n",
      "[Epoch 3500/10000] [MSE loss: 0.015967] [G loss: 14.660105] [D loss: 3.174863] [Q loss: 3.748576] [Phy loss: 2.427707] [Adv G loss: 10.345915]\n",
      "[Epoch 3600/10000] [MSE loss: 0.027254] [G loss: 23.211353] [D loss: 3.938468] [Q loss: 4.082787] [Phy loss: 2.640316] [Adv G loss: 18.740090]\n",
      "[Epoch 3700/10000] [MSE loss: 0.039031] [G loss: 22.734183] [D loss: 3.289691] [Q loss: 4.512285] [Phy loss: 2.479196] [Adv G loss: 18.490882]\n",
      "[Epoch 3800/10000] [MSE loss: 0.013756] [G loss: 24.620923] [D loss: 3.323984] [Q loss: 3.732545] [Phy loss: 2.439700] [Adv G loss: 20.129823]\n",
      "[Epoch 3900/10000] [MSE loss: 0.014549] [G loss: 26.849153] [D loss: 1.818330] [Q loss: 4.233889] [Phy loss: 2.736320] [Adv G loss: 21.965686]\n",
      "[Epoch 4000/10000] [MSE loss: 0.009659] [G loss: 15.623149] [D loss: 4.311182] [Q loss: 4.305923] [Phy loss: 2.423661] [Adv G loss: 11.213350]\n",
      "[Epoch 4100/10000] [MSE loss: 0.015538] [G loss: 21.467485] [D loss: 3.105213] [Q loss: 3.906332] [Phy loss: 2.593238] [Adv G loss: 16.825591]\n",
      "[Epoch 4200/10000] [MSE loss: 0.011551] [G loss: 14.278555] [D loss: 4.469638] [Q loss: 4.224377] [Phy loss: 2.460871] [Adv G loss: 10.013955]\n",
      "[Epoch 4300/10000] [MSE loss: 0.009626] [G loss: 18.095202] [D loss: 3.692443] [Q loss: 3.647299] [Phy loss: 2.458643] [Adv G loss: 13.432208]\n",
      "[Epoch 4400/10000] [MSE loss: 0.014011] [G loss: 25.054496] [D loss: 2.714075] [Q loss: 3.807519] [Phy loss: 2.605408] [Adv G loss: 20.656348]\n",
      "[Epoch 4500/10000] [MSE loss: 0.011577] [G loss: 24.380303] [D loss: 3.716057] [Q loss: 3.451740] [Phy loss: 2.285508] [Adv G loss: 19.911316]\n",
      "[Epoch 4600/10000] [MSE loss: 0.028136] [G loss: 21.801371] [D loss: 2.468857] [Q loss: 4.244556] [Phy loss: 2.633201] [Adv G loss: 17.117082]\n",
      "[Epoch 4700/10000] [MSE loss: 0.016437] [G loss: 23.236608] [D loss: 2.636753] [Q loss: 3.860542] [Phy loss: 2.386697] [Adv G loss: 18.865367]\n",
      "[Epoch 4800/10000] [MSE loss: 0.017192] [G loss: 23.362090] [D loss: 3.960026] [Q loss: 4.357440] [Phy loss: 2.877330] [Adv G loss: 18.356732]\n",
      "[Epoch 4900/10000] [MSE loss: 0.014959] [G loss: 24.828041] [D loss: 2.532121] [Q loss: 3.974415] [Phy loss: 2.292516] [Adv G loss: 20.486817]\n",
      "[Epoch 5000/10000] [MSE loss: 0.019320] [G loss: 22.865495] [D loss: 2.422514] [Q loss: 4.176688] [Phy loss: 2.318840] [Adv G loss: 18.457901]\n",
      "[Epoch 5100/10000] [MSE loss: 0.017204] [G loss: 29.496193] [D loss: 2.771257] [Q loss: 4.059619] [Phy loss: 2.343361] [Adv G loss: 25.256474]\n",
      "[Epoch 5200/10000] [MSE loss: 0.011545] [G loss: 27.207803] [D loss: 5.950809] [Q loss: 3.848574] [Phy loss: 2.347748] [Adv G loss: 22.898348]\n",
      "[Epoch 5300/10000] [MSE loss: 0.013729] [G loss: 19.437994] [D loss: 2.587489] [Q loss: 3.802550] [Phy loss: 2.696714] [Adv G loss: 14.788355]\n",
      "[Epoch 5400/10000] [MSE loss: 0.019480] [G loss: 23.024158] [D loss: 3.370834] [Q loss: 3.684030] [Phy loss: 2.449921] [Adv G loss: 18.465353]\n",
      "[Epoch 5500/10000] [MSE loss: 0.015952] [G loss: 32.094926] [D loss: 2.480354] [Q loss: 3.913014] [Phy loss: 2.409015] [Adv G loss: 27.819427]\n",
      "[Epoch 5600/10000] [MSE loss: 0.011844] [G loss: 21.277774] [D loss: 2.654734] [Q loss: 3.883377] [Phy loss: 2.329913] [Adv G loss: 17.113767]\n",
      "[Epoch 5700/10000] [MSE loss: 0.008985] [G loss: 21.726070] [D loss: 4.835294] [Q loss: 3.975314] [Phy loss: 2.308267] [Adv G loss: 16.964676]\n",
      "[Epoch 5800/10000] [MSE loss: 0.009573] [G loss: 26.253016] [D loss: 2.869214] [Q loss: 4.066131] [Phy loss: 2.396194] [Adv G loss: 22.010904]\n",
      "[Epoch 5900/10000] [MSE loss: 0.012130] [G loss: 25.210745] [D loss: 4.665115] [Q loss: 4.010054] [Phy loss: 2.539974] [Adv G loss: 20.593031]\n",
      "[Epoch 6000/10000] [MSE loss: 0.019953] [G loss: 21.269372] [D loss: 3.119401] [Q loss: 4.403744] [Phy loss: 2.206633] [Adv G loss: 17.145662]\n",
      "[Epoch 6100/10000] [MSE loss: 0.011784] [G loss: 22.430207] [D loss: 2.628510] [Q loss: 3.862758] [Phy loss: 2.311810] [Adv G loss: 18.172582]\n",
      "[Epoch 6200/10000] [MSE loss: 0.008997] [G loss: 27.168027] [D loss: 2.085050] [Q loss: 3.670282] [Phy loss: 2.116626] [Adv G loss: 23.040154]\n",
      "[Epoch 6300/10000] [MSE loss: 0.010485] [G loss: 28.862181] [D loss: 3.559625] [Q loss: 3.760962] [Phy loss: 2.306628] [Adv G loss: 24.399794]\n",
      "[Epoch 6400/10000] [MSE loss: 0.008501] [G loss: 16.871788] [D loss: 2.530300] [Q loss: 4.106524] [Phy loss: 2.297530] [Adv G loss: 12.553110]\n",
      "[Epoch 6500/10000] [MSE loss: 0.018719] [G loss: 22.906423] [D loss: 4.041834] [Q loss: 4.111098] [Phy loss: 2.346575] [Adv G loss: 18.777171]\n",
      "[Epoch 6600/10000] [MSE loss: 0.010485] [G loss: 33.534569] [D loss: 3.066648] [Q loss: 3.831390] [Phy loss: 2.319627] [Adv G loss: 29.508311]\n",
      "[Epoch 6700/10000] [MSE loss: 0.007775] [G loss: 20.377815] [D loss: 4.032734] [Q loss: 3.573643] [Phy loss: 2.221500] [Adv G loss: 16.192207]\n",
      "[Epoch 6800/10000] [MSE loss: 0.008528] [G loss: 19.453409] [D loss: 2.998010] [Q loss: 4.297939] [Phy loss: 2.282983] [Adv G loss: 15.303066]\n",
      "[Epoch 6900/10000] [MSE loss: 0.008340] [G loss: 13.415064] [D loss: 4.022995] [Q loss: 4.285845] [Phy loss: 2.166621] [Adv G loss: 9.417781]\n",
      "[Epoch 7000/10000] [MSE loss: 0.016983] [G loss: 21.564154] [D loss: 2.796269] [Q loss: 3.695093] [Phy loss: 2.478882] [Adv G loss: 16.829781]\n",
      "[Epoch 7100/10000] [MSE loss: 0.009842] [G loss: 26.884151] [D loss: 3.568466] [Q loss: 3.896984] [Phy loss: 2.001475] [Adv G loss: 22.949360]\n",
      "[Epoch 7200/10000] [MSE loss: 0.009049] [G loss: 15.125747] [D loss: 3.725089] [Q loss: 3.970270] [Phy loss: 2.091529] [Adv G loss: 11.378569]\n",
      "[Epoch 7300/10000] [MSE loss: 0.006003] [G loss: 23.283468] [D loss: 2.926188] [Q loss: 4.113854] [Phy loss: 2.214612] [Adv G loss: 19.091166]\n",
      "[Epoch 7400/10000] [MSE loss: 0.008719] [G loss: 22.047780] [D loss: 3.172589] [Q loss: 4.195017] [Phy loss: 2.191765] [Adv G loss: 17.993535]\n",
      "[Epoch 7500/10000] [MSE loss: 0.007737] [G loss: 30.384345] [D loss: 2.710881] [Q loss: 4.371366] [Phy loss: 2.399962] [Adv G loss: 25.962936]\n",
      "[Epoch 7600/10000] [MSE loss: 0.009361] [G loss: 15.081249] [D loss: 3.012255] [Q loss: 3.819379] [Phy loss: 2.101937] [Adv G loss: 11.013471]\n",
      "[Epoch 7700/10000] [MSE loss: 0.012643] [G loss: 20.454526] [D loss: 4.012055] [Q loss: 3.443667] [Phy loss: 2.384078] [Adv G loss: 16.115645]\n",
      "[Epoch 7800/10000] [MSE loss: 0.006984] [G loss: 24.806816] [D loss: 5.014605] [Q loss: 4.265951] [Phy loss: 2.118725] [Adv G loss: 20.983518]\n",
      "[Epoch 7900/10000] [MSE loss: 0.018377] [G loss: 43.161575] [D loss: 1.357276] [Q loss: 4.596826] [Phy loss: 2.517167] [Adv G loss: 38.204001]\n",
      "[Epoch 8000/10000] [MSE loss: 0.014645] [G loss: 29.192479] [D loss: 4.985075] [Q loss: 4.140871] [Phy loss: 2.146782] [Adv G loss: 25.382518]\n",
      "[Epoch 8100/10000] [MSE loss: 0.012342] [G loss: 28.722197] [D loss: 3.490904] [Q loss: 3.868481] [Phy loss: 2.103441] [Adv G loss: 24.340578]\n",
      "[Epoch 8200/10000] [MSE loss: 0.009253] [G loss: 23.091599] [D loss: 4.166047] [Q loss: 3.481077] [Phy loss: 2.455923] [Adv G loss: 18.764844]\n",
      "[Epoch 8300/10000] [MSE loss: 0.007136] [G loss: 20.278887] [D loss: 4.593384] [Q loss: 3.940290] [Phy loss: 2.214535] [Adv G loss: 16.281655]\n",
      "[Epoch 8400/10000] [MSE loss: 0.009165] [G loss: 25.727989] [D loss: 2.586556] [Q loss: 4.050929] [Phy loss: 2.262796] [Adv G loss: 21.513848]\n",
      "[Epoch 8500/10000] [MSE loss: 0.010437] [G loss: 28.115622] [D loss: 3.882232] [Q loss: 3.907782] [Phy loss: 2.166184] [Adv G loss: 24.009500]\n",
      "[Epoch 8600/10000] [MSE loss: 0.007178] [G loss: 21.966361] [D loss: 3.633544] [Q loss: 4.056418] [Phy loss: 2.281792] [Adv G loss: 17.687973]\n",
      "[Epoch 8700/10000] [MSE loss: 0.008723] [G loss: 18.895440] [D loss: 2.895300] [Q loss: 4.180674] [Phy loss: 2.258306] [Adv G loss: 14.572143]\n",
      "[Epoch 8800/10000] [MSE loss: 0.012771] [G loss: 34.495065] [D loss: 3.776484] [Q loss: 3.736182] [Phy loss: 2.269112] [Adv G loss: 30.198858]\n",
      "[Epoch 8900/10000] [MSE loss: 0.010945] [G loss: 21.741308] [D loss: 2.912054] [Q loss: 3.797025] [Phy loss: 2.244837] [Adv G loss: 17.642595]\n",
      "[Epoch 9000/10000] [MSE loss: 0.008155] [G loss: 15.149514] [D loss: 3.849273] [Q loss: 3.946218] [Phy loss: 2.332803] [Adv G loss: 11.092152]\n",
      "[Epoch 9100/10000] [MSE loss: 0.008734] [G loss: 23.623675] [D loss: 3.364208] [Q loss: 3.604650] [Phy loss: 2.314841] [Adv G loss: 19.236354]\n",
      "[Epoch 9200/10000] [MSE loss: 0.022171] [G loss: 19.831539] [D loss: 3.589985] [Q loss: 3.957517] [Phy loss: 2.075393] [Adv G loss: 15.659091]\n",
      "[Epoch 9300/10000] [MSE loss: 0.008715] [G loss: 16.615057] [D loss: 3.709782] [Q loss: 3.832774] [Phy loss: 2.216365] [Adv G loss: 12.683680]\n",
      "[Epoch 9400/10000] [MSE loss: 0.011944] [G loss: 21.097689] [D loss: 2.763174] [Q loss: 3.903117] [Phy loss: 2.310482] [Adv G loss: 16.779105]\n",
      "[Epoch 9500/10000] [MSE loss: 0.013905] [G loss: 30.887681] [D loss: 1.449819] [Q loss: 4.277805] [Phy loss: 2.252866] [Adv G loss: 26.639115]\n",
      "[Epoch 9600/10000] [MSE loss: 0.009699] [G loss: 24.346120] [D loss: 3.950540] [Q loss: 4.007005] [Phy loss: 2.355696] [Adv G loss: 19.774517]\n",
      "[Epoch 9700/10000] [MSE loss: 0.011893] [G loss: 36.453032] [D loss: 1.425688] [Q loss: 3.526996] [Phy loss: 2.358501] [Adv G loss: 31.976448]\n",
      "[Epoch 9800/10000] [MSE loss: 0.009577] [G loss: 24.339614] [D loss: 2.886675] [Q loss: 4.210057] [Phy loss: 2.077162] [Adv G loss: 20.287664]\n",
      "[Epoch 9900/10000] [MSE loss: 0.005792] [G loss: 25.272064] [D loss: 5.182652] [Q loss: 3.908028] [Phy loss: 2.068506] [Adv G loss: 21.251492]\n"
     ]
    }
   ],
   "source": [
    "# Run the training loop (the recorded output shows one log line every 100 epochs).\n",
    "PIG.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MSE on Training Set :  0.14441592\n",
      "Physical Consistency:  0.73300487\n"
     ]
    }
   ],
   "source": [
    "# Monte-Carlo evaluation on the training split: draw `nsamples` noise batches,\n",
    "# average the generator outputs and score the mean prediction.\n",
    "nsamples = 500\n",
    "train_predictions_list = []\n",
    "# Sampling needs no autograd graph, so gradient tracking is disabled for the loop.\n",
    "with torch.no_grad():\n",
    "    for run in range(nsamples):\n",
    "        G_train_noise = PIG.sample_noise(PIG.train_x.shape[0], PIG.noise_dim).to(device)\n",
    "        run_predictions = PIG.G.forward(torch.cat([PIG.train_x, G_train_noise], dim=1)).detach().cpu().numpy()\n",
    "        # Rescale with PIG.Ystd / PIG.Ymean (presumably undoing the label\n",
    "        # standardisation applied inside Tossing_PIG - confirm).\n",
    "        train_predictions_list.append((run_predictions * PIG.Ystd) + PIG.Ymean)\n",
    "\n",
    "train_predictions = np.mean(train_predictions_list, axis=0)\n",
    "# NOTE: despite the `_dev` suffix this is the per-element variance across samples.\n",
    "train_predictions_dev = np.var(train_predictions_list, axis=0)\n",
    "train_predictions = torch.Tensor(train_predictions).float().to(device)\n",
    "\n",
    "# Scaling statistics as device tensors (the test-set evaluation cell reuses them).\n",
    "std_y_t = torch.Tensor(PIG.Ystd).float().to(device)\n",
    "mean_y_t = torch.Tensor(PIG.Ymean).float().to(device)\n",
    "std_x_t = torch.Tensor(PIG.Xstd).float().to(device)\n",
    "mean_x_t = torch.Tensor(PIG.Xmean).float().to(device)\n",
    "\n",
    "# Keep the rescaled tensors under their own names so the raw NumPy `train_x` /\n",
    "# `train_y` from the data-loading cell are not overwritten; otherwise re-running\n",
    "# the Tossing_PIG cell would silently receive device tensors instead.\n",
    "train_y_denorm = (PIG.train_y * std_y_t) + mean_y_t\n",
    "train_x_denorm = (PIG.train_x * std_x_t) + mean_x_t\n",
    "\n",
    "mse_train = ((train_predictions - train_y_denorm)**2).mean()\n",
    "# physics_loss is deliberately left outside no_grad in case it relies on autograd.\n",
    "phy_train = torch.mean(PIG.physics_loss(train_x_denorm, train_predictions, [0, 1], [0, 1])).detach().cpu().numpy()\n",
    "\n",
    "print(\"MSE on Training Set : \", mse_train.detach().cpu().numpy())\n",
    "print(\"Physical Consistency: \", phy_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MSE on Test Set :  0.18706776\n",
      "Physical Consistency:  0.45116368\n"
     ]
    }
   ],
   "source": [
    "# Monte-Carlo evaluation on the test split. Relies on `nsamples` and the\n",
    "# `*_t` scaling tensors defined in the training-set evaluation cell.\n",
    "test_predictions_list = []\n",
    "# Sampling needs no autograd graph, so gradient tracking is disabled for the loop.\n",
    "with torch.no_grad():\n",
    "    for run in range(nsamples):\n",
    "        G_test_noise = PIG.sample_noise(PIG.test_x.shape[0], PIG.noise_dim).to(device)\n",
    "        run_predictions = PIG.G.forward(torch.cat([PIG.test_x, G_test_noise], dim=1)).detach().cpu().numpy()\n",
    "        # Rescale with PIG.Ystd / PIG.Ymean (presumably undoing the label\n",
    "        # standardisation applied inside Tossing_PIG - confirm).\n",
    "        test_predictions_list.append((run_predictions * PIG.Ystd) + PIG.Ymean)\n",
    "\n",
    "test_predictions = np.mean(test_predictions_list, axis=0)\n",
    "# NOTE: despite the `_dev` suffix this is the per-element variance across samples.\n",
    "test_predictions_dev = np.var(test_predictions_list, axis=0)\n",
    "test_predictions = torch.Tensor(test_predictions).float().to(device)\n",
    "\n",
    "# Keep the rescaled tensors under their own names so the raw NumPy `test_x` /\n",
    "# `test_y` from the data-loading cell are not overwritten; otherwise re-running\n",
    "# the Tossing_PIG cell would silently receive device tensors instead.\n",
    "test_y_denorm = (PIG.test_y * std_y_t) + mean_y_t\n",
    "test_x_denorm = (PIG.test_x * std_x_t) + mean_x_t\n",
    "\n",
    "mse_test = ((test_predictions - test_y_denorm)**2).mean()\n",
    "# physics_loss is deliberately left outside no_grad in case it relies on autograd.\n",
    "phy_test = torch.mean(PIG.physics_loss(test_x_denorm, test_predictions, [0, 1], [0, 1])).detach().cpu().numpy()\n",
    "\n",
    "print(\"MSE on Test Set : \", mse_test.detach().cpu().numpy())\n",
    "print(\"Physical Consistency: \", phy_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
